from math import pi
from ase.io import read
from ase.constraints import FixAtoms
from ase.calculators.siesta import Siesta
#from ase.calculators.siesta.parameters import Specie, PAOBasisBlock
from ase.units import *
from os import listdir
from numpy.linalg import norm
from numpy import loadtxt, savetxt
from ase.optimize import BFGS, FIRE
#from ase.neb import NEB 
from ase.dimer import DimerControl, MinModeAtoms, MinModeTranslate

# Dimer-search parameters
force_tolerance = 0.014  # convergence threshold passed to run(fmax=...), in eV/Ang
# dm_tolerance = 5.0e-7  -- NOT used by this script: the density-matrix
#                           tolerance actually applied is the calculator's
#                           'DM.Tolerance' fdf setting.
dimer_dist = 0.005  # dimer separation handed to DimerControl (presumably Angstrom)
guess_name = 'siesta.STRUCT_OUT'  # starting structure, read only when not restarting
prefix = 'dimer'  # calculator label and base name of the dimer output files
NNdist = 1.94  # target atom0-atom1 distance for the initial mode; only used if not restarting

# Attempt a restart: if <prefix>.traj exists in the working directory, resume
# from its last frame together with the eigenmode stored in <prefix>.MODE.
# Otherwise build the initial structure and mode from `guess_name`.
restart = ('%s.traj' % prefix) in listdir('.')

if restart:
    print("Found %s.traj, attempting to restart..." % prefix)
    try:
        dimer = read('%s.traj@-1' % prefix)
        mode = loadtxt('%s.MODE' % prefix)
    except Exception as err:
        print("ERROR: %s" % err)
        # Bare `raise` re-raises with the original traceback
        # (`raise err` discards it on Python 2).
        raise
    print('Restarting from saved trajectory')
else:
    dimer = read(guess_name)
    old_positions = dimer.positions.copy()
    mode_positions = dimer.positions.copy()

    # Move atoms 0 and 1 symmetrically along their connecting line, about
    # their midpoint `cm`, so that their separation becomes NNdist.
    r1 = dimer.positions[0]
    r2 = dimer.positions[1]
    cm = (r1 + r2) / 2.0
    dcm = norm(r2 - cm)  # half of the current atom0-atom1 separation
    if dcm == 0.0:
        raise ValueError('atoms 0 and 1 coincide in %s; cannot build the '
                         'initial dimer mode' % guess_name)
    r1 = cm + (NNdist / 2.0) * (r1 - cm) / dcm
    r2 = cm + (NNdist / 2.0) * (r2 - cm) / dcm
    mode_positions[0] = r1  # can also set dimer.positions[0]=r1; similar for r2
    mode_positions[1] = r2

    # The initial eigenmode is the normalized displacement toward NNdist.
    mode = mode_positions - old_positions
    mode_norm = norm(mode.ravel())
    if mode_norm < 1e-10:
        # Separation already equals NNdist: the displacement is (numerically)
        # zero, and normalizing it would give NaN or a random-noise direction.
        raise ValueError('atoms 0 and 1 in %s are already %.4f apart; the '
                         'initial dimer mode would be zero' % (guess_name, NNdist))
    # NOTE (RAH): stock ASE does not normalize the input mode(s); the
    # customized ASE used here does, but normalizing here works with either.
    mode = mode / mode_norm

# Periodic boundary conditions along all three cell vectors.
dimer.set_pbc((True,) * 3)

# Freeze atoms with index 2 and 3. Per the calculator's species list these are
# species-1 (N) atoms -- presumably that list is in atom order; confirm.
frozen = [2 <= atom.index <= 3 for atom in dimer]
constraint = FixAtoms(mask=frozen)
dimer.set_constraint(constraint)

# Customized Siesta calculator. customSpeciesIndices assigns a species index to
# each atom (presumably in atom order): 1 = N, 2 = Ru_surf, 3 = Ru_bulk, per the
# Chemical_Species_Label block set further down.
sc = Siesta(label=prefix,customSpeciesIndices=[1,1,1,1,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3])
#sc = Siesta(label=name)

# General run settings, applied in exactly this order.
basic_fdf_settings = (
    ('MeshCutoff', 120.0 * Ry),
    ('SpinPolarized', False),
    ('WriteDM', True),
    ('DM.UseSaveDM', True),
    #('FixSpin', False),
    #('TotalSpin', 5),
    #('WriteMullikenPop', 1),
    ('SystemName', prefix),
    ('SystemLabel', prefix),
    ('SCFMustConverge', False),
    ('MaxSCFIterations', 400),
    ('NumberOfSpecies', 3),
)
for fdf_key, fdf_value in basic_fdf_settings:
    sc.set_fdf(fdf_key, fdf_value)
# Species table for SIESTA. Each line is "<species index> <atomic number> <label>":
# N is Z=7, and both Ru species (surface and bulk) are Z=44 but get separate
# basis sets below. The text is passed as a one-element list of the raw block.
sc.set_fdf('Chemical_Species_Label',["""\
1 7 N
2 44 Ru_surf
3 44 Ru_bulk                                   """])
# Alternative (disabled): use a generic DZP basis instead of the explicit
# PAO.Basis block below.
# sc.set_fdf('PAO.BasisSize', 'DZP')
# Explicit PAO basis per species label. The header line is "<label> <number of
# l-shells>", followed per shell by "n=<n> <l> <Nzeta> [options]", the cutoff
# radii, and the scale factors. N uses split-norm ('S') s/p shells plus a d
# shell. The Ru shells use soft confinement ('E' followed by potential and
# radius, per the in-block comment). Text after '#' inside the string is left
# for the fdf parser, which treats '#' as a comment marker.
sc.set_fdf('PAO.Basis',["""\
N     3     
 n=2    0    2  S .2305887
   5.2553248   0.0
   1.000   1.000
 n=2    1    2  S .2846094
   5.2374770   0.0
   1.000   1.000  
 n=3    2    1  
   2.5406787
   1.000  
Ru_surf    3                    # Species label, number of l-shells, effective charge
 n=5   0   2   E  5.1969765  6.4102033   # n, l, Nzeta, soft confinement (sc), sc potential, sc radius
   4.8150212    -.5408860   
   1.000      1.000   
 n=4   2   2   E  16.9399325  6.4404240  
   5.2358603    -.0750774  
   1.000      1.000  
 n=5   1   1   E  16.7403551  6.7658275
   7.4372332
   1.00
Ru_bulk    3                    # Species label, number of l-shells, effective charge
 n=5   0   2   E  5.1276155  6.2372922   # n, l, Nzeta, soft confinement (sc), sc potential, sc radius
   4.6336805    -.5019498   
   1.000      1.000   
 n=4   2   2   E  16.8979148  6.7782168  
   5.0620075    -.0825330  
   1.000      1.000  
 n=5   1   1   E  16.5510519  8.0000000
   5.2883443
   1.00                            """])
# SCF convergence, mixing, exchange-correlation and electronic smearing.
sc.set_fdf('DM.Tolerance', 1.0e-5)
sc.set_fdf('XC.Functional', 'GGA')
sc.set_fdf('XC.Authors', 'PBE')
sc.set_fdf('DM.MixingWeight', 0.02)
sc.set_fdf('DM.NumberPulay', 8)
sc.set_fdf('ElectronicTemperature', 300.0 * kB)  # 300 K converted with ase.units.kB
# DM.UseSaveDM is already enabled together with the general run settings
# (MeshCutoff, WriteDM, ...); a second identical set_fdf call here was redundant.
sc.set_fdf('WriteMDXmol', True)
# TD.* keys are presumably options of a customized SIESTA build -- confirm.
sc.set_fdf('TD.Adiabatic', True)
sc.set_fdf('TD.ntime', 0)
sc.set_fdf('TD.Fermi', True)
dimer.set_calculator(sc)

# DIMER SETUP
# Min-mode (dimer) control parameters. Lengths are presumably in Angstrom and
# trial_angle in radians -- confirm against the ase.dimer documentation.
control_settings = {
    'initial_eigenmode_method': 'displacement',
    'displacement_method': 'vector',
    'dimer_separation': dimer_dist,
    'trial_trans_step': 5.0e-3,
    'maximum_translation': 0.10,
    'f_rot_max': 0.2,
    'trial_angle': pi / 4,
    'max_num_rot': 2,
    'logfile': '%s_control.log' % prefix,
    'eigenmode_logfile': '%s_eigenmode.log' % prefix,
}
dimer_control = DimerControl(**control_settings)

# Wrap the atoms for the min-mode search; the starting eigenmode is `mode`,
# either freshly constructed or reloaded from <prefix>.MODE.
dimer_atoms = MinModeAtoms(dimer, dimer_control, eigenmodes=[mode])
#dimer_atoms.displace(displacement_vector = mode)

# Saddle-point translation optimizer, recording to <prefix>.traj (the file
# the restart logic looks for) and logging to <prefix>_relax.log.
relax_trajectory = '%s.traj' % prefix
relax_logfile = '%s_relax.log' % prefix
dimer_relax = MinModeTranslate(dimer_atoms,
                               trajectory=relax_trajectory,
                               logfile=relax_logfile)
# Other optimizers that can drive dimer_atoms instead:
#dimer_relax = FIRE(dimer_atoms, trajectory='dimer.traj',
#                   restart='restart_file',
#                   dt=0.1, # default 0.1
#                   maxmove=0.2, # default 0.2
#                   dtmax=1.0, # default 1.0
#                   Nmin=3, # default 5
#                   finc=1.1, # default 3
#                   fdec=0.71, # default 0.5
#                   astart=0.2, # default 0.1
#                   fa=0.99) # default 0.99
#dimer_relax = BFGS(dimer_atoms, trajectory='dimer.traj', restart='restart_file')

# RUN DIMER
def save_mode():
    """Write the current dimer eigenmode to <prefix>.MODE."""
    savetxt('%s.MODE' % prefix, dimer_atoms.get_eigenmode())

# Save the eigenmode after every optimizer step, not only at the end. The
# restart branch needs <prefix>.MODE alongside <prefix>.traj. Previously an
# interrupted run left a trajectory but no mode file, so a restart could
# never succeed.
dimer_relax.attach(save_mode, interval=1)
dimer_relax.run(fmax=force_tolerance)

# clean up: store the final eigenmode
save_mode()
